package top.lingkang.hibernate5.config;

import org.hibernate.Session;
import org.hibernate.SessionFactory;
import org.hibernate.StatelessSession;
import top.lingkang.hibernate5.transaction.HibernateTranInterceptor;

import javax.persistence.EntityManager;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

/**
 * Dynamic-proxy handler that sits in front of a {@link SessionFactory}.
 *
 * <p>While {@code HibernateTranInterceptor.tranNumber} holds a non-null value for the
 * calling thread, calls that open a {@link Session}, {@link EntityManager} or
 * {@link StatelessSession} are answered with a per-thread instance whose transaction
 * has already been begun, so repeated calls on the same thread share one instance.
 * Every other call (and every call made while no transaction is flagged) is forwarded
 * to the wrapped factory unchanged.
 *
 * <p>This class only <em>adds</em> resources to {@link #threadLocal}; it never commits,
 * closes or removes them. NOTE(review): presumably {@code HibernateTranInterceptor}
 * does that when the transaction ends — confirm against that class.
 *
 * @author lingkang
 * Created on 2023/12/2
 */
public class SessionFactoryInvocationHandler implements InvocationHandler {
    /** The real factory every call is ultimately delegated to. */
    private final SessionFactory sessionFactory;

    /**
     * Resources opened for the current thread by this handler. The list may hold
     * {@link Session}, {@link EntityManager} and {@link StatelessSession} instances.
     * Kept public and non-final for backward compatibility with existing callers.
     */
    public static ThreadLocal<List<Object>> threadLocal = ThreadLocal.withInitial(ArrayList::new);

    /**
     * @param sessionFactory the factory to wrap; all proxied calls are delegated to it
     */
    public SessionFactoryInvocationHandler(SessionFactory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Intercepts resource-opening calls while a transaction is flagged for this thread
     * and delegates everything else to the wrapped factory.
     *
     * <p>Methods are matched by name prefix. NOTE(review): any other proxied method
     * sharing one of these prefixes would be intercepted as well — confirm the proxied
     * interface has none.
     *
     * @param proxy  the proxy instance the call was made on (unused)
     * @param method the interface method being called
     * @param args   the call arguments, or {@code null} when there are none
     * @return the thread-bound resource, or the wrapped factory's own result
     * @throws Throwable whatever the wrapped factory method throws, unwrapped from the
     *                   reflective {@link InvocationTargetException}
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        String name = method.getName();
        if (name.startsWith("openSession") || name.startsWith("getCurrent")) {
            if (inTransaction()) {
                return boundSession();
            }
        } else if (name.startsWith("createEn")) {
            if (inTransaction()) {
                return boundEntityManager(method, args);
            }
        } else if (name.startsWith("openSta")) {
            if (inTransaction()) {
                return boundStatelessSession(method, args);
            }
        }

        return invokeTarget(method, args);
    }

    /** Returns whether the interceptor has flagged a transaction for the current thread. */
    private static boolean inTransaction() {
        return HibernateTranInterceptor.tranNumber.get() != null;
    }

    /** Returns this thread's session, opening one and beginning its transaction on first use. */
    private Session boundSession() {
        Session existing = get(Session.class);
        if (existing != null) {
            return existing;
        }
        Session created = sessionFactory.openSession();
        beginOrClose(() -> created.getTransaction().begin(), created::close);
        set(created);
        return created;
    }

    /** Returns this thread's entity manager, creating one and beginning its transaction on first use. */
    private EntityManager boundEntityManager(Method method, Object[] args) throws Throwable {
        EntityManager existing = get(EntityManager.class);
        if (existing != null) {
            return existing;
        }
        EntityManager created = (EntityManager) invokeTarget(method, args);
        beginOrClose(() -> created.getTransaction().begin(), created::close);
        set(created);
        return created;
    }

    /** Returns this thread's stateless session, opening one and beginning its transaction on first use. */
    private StatelessSession boundStatelessSession(Method method, Object[] args) throws Throwable {
        StatelessSession existing = get(StatelessSession.class);
        if (existing != null) {
            return existing;
        }
        StatelessSession created = (StatelessSession) invokeTarget(method, args);
        beginOrClose(() -> created.getTransaction().begin(), created::close);
        set(created);
        return created;
    }

    /**
     * Runs {@code begin}; if it fails, runs {@code close} before rethrowing so a resource
     * that was opened but never registered in {@link #threadLocal} is not leaked.
     */
    private static void beginOrClose(Runnable begin, Runnable close) {
        try {
            begin.run();
        } catch (RuntimeException | Error failure) {
            try {
                close.run();
            } catch (RuntimeException closeFailure) {
                // Keep the original failure as the primary error.
                failure.addSuppressed(closeFailure);
            }
            throw failure;
        }
    }

    /**
     * Calls {@code method} on the wrapped factory and rethrows the target's own exception.
     * Without the unwrap, the checked {@link InvocationTargetException} would escape the
     * handler and reach proxy callers as an {@code UndeclaredThrowableException}.
     */
    private Object invokeTarget(Method method, Object[] args) throws Throwable {
        try {
            return method.invoke(sessionFactory, args);
        } catch (InvocationTargetException e) {
            Throwable cause = e.getCause();
            throw cause != null ? cause : e;
        }
    }

    /**
     * Finds the first resource registered for this thread that is an instance of {@code clazz}.
     *
     * @return the matching resource, or {@code null} if none has been registered
     */
    private <T> T get(Class<T> clazz) {
        for (Object candidate : threadLocal.get()) {
            if (clazz.isInstance(candidate)) {
                return clazz.cast(candidate);
            }
        }
        return null;
    }

    /** Registers {@code o} in this thread's resource list. */
    private void set(Object o) {
        threadLocal.get().add(o);
    }
}
